import functools

import jax
import scalevi.objectives.objectives as objectives


def get_objective(model, var_dist, config_dict):
	"""Build the training, evaluation and test objectives for a run.

	The training objective is selected by ``config_dict['objective']`` and is
	wrapped in ``objectives._multiple_obj_copies`` (aggregated with
	``'mean'``). The evaluation objective wraps ``objectives.eval_elbo`` and
	the test objective wraps ``objectives.eval_test_ll``.

	Args:
		model: Model exposing ``log_prob`` and ``eval_child_ll`` (the latter
			is bound with ``use_test=True`` for the test objective).
		var_dist: Variational distribution exposing ``sample_and_log_prob``;
			``log_prob`` and ``sample`` are also required unless the objective
			is "CFE ELBO".
		config_dict: Run configuration. Reads ``'objective'``, ``'N_leaves'``
			and ``'minibatch_use'``, plus the optional ``'num_copies_train'``
			(default 1) and ``'num_copies_eval'`` (default 10).

	Returns:
		Tuple ``(obj_train, obj_eval, obj_test)`` of partially applied
		objective callables.

	Raises:
		NotImplementedError: If the objective is not "ELBO", "Total ELBO" or
			"CFE ELBO".
		KeyError: If a required configuration key is missing.
	"""
	objective_name = config_dict['objective']

	# Resolve the training objective by attribute name so only the selected
	# function has to exist on ``objectives``.
	train_attr = {
		'ELBO': 'elbo',
		'Total ELBO': 'total_elbo',
		'CFE ELBO': 'cfe_elbo',
	}.get(objective_name)
	if train_attr is None:
		raise NotImplementedError("Other objectives not implemented")
	train_fn = getattr(objectives, train_attr)

	# The CFE ELBO takes a joint sampler/log-density callable; the other
	# objectives take separate log-density and sampling callables.
	if objective_name == 'CFE ELBO':
		q_kwargs = {'sample_and_log_q': var_dist.sample_and_log_prob}
	else:
		q_kwargs = {'log_q': var_dist.log_prob, 'sample_q': var_dist.sample}

	single_obj_train = functools.partial(
		train_fn, **q_kwargs, log_p=model.log_prob)

	obj_train = functools.partial(
		objectives._multiple_obj_copies,
		obj=single_obj_train,
		num_copies=config_dict.get('num_copies_train', 1),
		agg_func='mean')

	# Settings shared by the evaluation and test objectives.
	eval_kwargs = {
		'n_chunk': config_dict['N_leaves'],
		'num_copies_eval': config_dict.get("num_copies_eval", 10),
		'minibatch_use': config_dict['minibatch_use'],
	}

	obj_eval = functools.partial(
		objectives.eval_elbo,
		log_p=model.log_prob,
		sample_and_log_q=var_dist.sample_and_log_prob,
		**eval_kwargs)

	obj_test = functools.partial(
		objectives.eval_test_ll,
		log_p=functools.partial(model.eval_child_ll, use_test=True),
		sample_and_log_q=functools.partial(
			var_dist.sample_and_log_prob, use_test=False),
		**eval_kwargs)

	return obj_train, obj_eval, obj_test

def get_objective_temp(model, var_dist, config_dict):
	"""Build train-set and test-set log-likelihood evaluators.

	Args:
		model: Model exposing ``log_prob`` (used by the train-set evaluator)
			and ``eval_child_ll`` (bound with ``use_test=True`` for the
			test-set evaluator).
		var_dist: Variational distribution exposing ``sample_and_log_prob``,
			bound with ``use_test=False`` for both evaluators.
		config_dict: Run configuration. Reads ``'objective'``, ``'N_leaves'``
			and ``'minibatch_use'``, plus the optional ``'num_copies_eval'``
			(default 10).

	Returns:
		Tuple ``(obj_eval_train, obj_eval_test)`` wrapping
		``objectives.eval_train_ll`` and ``objectives.eval_test_ll``.

	Raises:
		NotImplementedError: If the objective is not "ELBO", "Total ELBO" or
			"CFE ELBO".
		KeyError: If a required configuration key is missing.
	"""
	# The training objective itself is not needed here, but unsupported
	# names are still rejected so a bad config fails early.
	if config_dict['objective'] not in ('ELBO', 'Total ELBO', 'CFE ELBO'):
		raise NotImplementedError("Other objectives not implemented")

	# NOTE(review): q is bound with use_test=False for both evaluators,
	# including the test-set one -- confirm this is intended.
	sample_and_log_q = functools.partial(
		var_dist.sample_and_log_prob, use_test=False)

	# Settings shared by both evaluators.
	eval_kwargs = {
		'n_chunk': config_dict['N_leaves'],
		'num_copies_eval': config_dict.get("num_copies_eval", 10),
		'minibatch_use': config_dict['minibatch_use'],
	}

	# NOTE: an alternative train-set log_p that was tried previously is
	# functools.partial(model.eval_child_ll, use_test=False).
	obj_eval_train = functools.partial(
		objectives.eval_train_ll,
		log_p=model.log_prob,
		sample_and_log_q=sample_and_log_q,
		**eval_kwargs)

	obj_eval_test = functools.partial(
		objectives.eval_test_ll,
		log_p=functools.partial(model.eval_child_ll, use_test=True),
		sample_and_log_q=sample_and_log_q,
		**eval_kwargs)

	return obj_eval_train, obj_eval_test

def get_objective_temp_alternate(model, var_dist, config_dict):
	"""Build train-set and test-set mean log-likelihood evaluators.

	Both evaluators wrap ``objectives.eval_mean_ll``. They differ only in the
	``use_test`` flag bound to ``model.eval_child_mean_ll``: ``False`` for the
	train evaluator and ``True`` for the test evaluator.

	Args:
		model: Model exposing ``eval_child_mean_ll``.
		var_dist: Variational distribution exposing ``sample_and_log_prob``,
			bound with ``use_test=False`` for both evaluators.
		config_dict: Run configuration. Reads ``'objective'``, ``'N_leaves'``
			and ``'minibatch_use'``, plus the optional ``'num_copies_eval'``
			(default 10).

	Returns:
		Tuple ``(obj_eval_train, obj_eval_test)`` of partially applied
		``objectives.eval_mean_ll`` callables.

	Raises:
		NotImplementedError: If the objective is not "ELBO", "Total ELBO" or
			"CFE ELBO".
		KeyError: If a required configuration key is missing.
	"""
	# The training objective itself is not needed here, but unsupported
	# names are still rejected so a bad config fails early.
	if config_dict['objective'] not in ('ELBO', 'Total ELBO', 'CFE ELBO'):
		raise NotImplementedError("Other objectives not implemented")

	def _mean_ll_evaluator(use_test):
		# Bind eval_mean_ll with the model's mean log-likelihood for the
		# given use_test flag; q is always bound with use_test=False.
		return functools.partial(
			objectives.eval_mean_ll,
			log_p=functools.partial(model.eval_child_mean_ll, use_test=use_test),
			sample_and_log_q=functools.partial(
				var_dist.sample_and_log_prob, use_test=False),
			n_chunk=config_dict['N_leaves'],
			num_copies_eval=config_dict.get("num_copies_eval", 10),
			minibatch_use=config_dict['minibatch_use'])

	obj_eval_train = _mean_ll_evaluator(False)
	obj_eval_test = _mean_ll_evaluator(True)
	return obj_eval_train, obj_eval_test

# def get_final_evaluation_obj(
# 							optimized_params,
# 							model,
# 							var_dist):

# 	_obj_eval = jax.partial(
# 					objectives.elbo, log_p = model.log_prob, 
# 					log_q = var_dist.log_prob,
# 					sample_q = var_dist.sample)

# 	obj_eval = jax.partial(
# 					objectives._multiple_obj_copies,
# 					var_params = optimized_params,
# 					rng_key=jax.random.PRNGKey(0),
# 					iter = 0,
# 					minibatch=None, 
# 					obj=_obj_eval, 
# 					agg_func='mean')
# 	return lambda x: obj_eval(num_copies = x)
# 	# return obj_eval
	